{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "08749765-5bf2-4305-a461-1d375fee41bb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import jieba\n",
    "import pandas as pd\n",
    "\n",
    "# Remote folder that hosts the competition files.\n",
    "data_dir = 'https://mirror.coggle.club/dataset/coggle-competition/'\n",
    "\n",
    "# Both files are read as tab-separated text without a header row, so the\n",
    "# columns keep pandas' default integer names.\n",
    "csv_options = dict(sep='\\t', header=None)\n",
    "train_data = pd.read_csv(data_dir + 'intent-classify/train.csv', **csv_options)\n",
    "test_data = pd.read_csv(data_dir + 'intent-classify/test.csv', **csv_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "81a0c43c-3e04-42eb-9c1c-a284e0f21178",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>还有双鸭山到淮阴的汽车票吗13号的</td>\n",
       "      <td>Travel-Query</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>从这里怎么回家</td>\n",
       "      <td>Travel-Query</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>随便播放一首专辑阁楼里的佛里的歌</td>\n",
       "      <td>Music-Play</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>给看一下墓王之王嘛</td>\n",
       "      <td>FilmTele-Play</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>我想看挑战两把s686打突变团竞的游戏视频</td>\n",
       "      <td>Video-Play</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                       0              1\n",
       "0      还有双鸭山到淮阴的汽车票吗13号的   Travel-Query\n",
       "1                从这里怎么回家   Travel-Query\n",
       "2       随便播放一首专辑阁楼里的佛里的歌     Music-Play\n",
       "3              给看一下墓王之王嘛  FilmTele-Play\n",
       "4  我想看挑战两把s686打突变团竞的游戏视频     Video-Play"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first five rows; the saved output shows the utterance text in\n",
    "# column 0 and the intent label in column 1.\n",
    "train_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d8492bd1-ad05-4e3b-a778-dcecb2573196",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Shuffle every row. A fixed random_state makes the row order (and the label\n",
    "# ids that pd.factorize later derives from it) identical on every fresh run.\n",
    "train_data = train_data.sample(frac=1.0, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6f67d9cb-04e9-4ab7-8642-acb0c4f58493",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Encode the intent strings in column 1 as integer codes; `lbl` keeps the\n",
    "# original label for each code so predictions can be mapped back later.\n",
    "label_codes, lbl = pd.factorize(train_data[1])\n",
    "train_data[1] = label_codes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "73d8e7e8-d520-43c9-bffa-a2d89c79a0e7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def coustom_data_iter(texts, labels):\n",
    "    for x, y in zip(texts, labels):\n",
    "        yield x, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b40c62f9-04f7-4c58-8a28-3b2799e31fae",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Column 0 holds the raw texts, column 1 the integer-encoded labels.\n",
    "train_texts = train_data[0].values\n",
    "train_labels = train_data[1].values\n",
    "train_iter = coustom_data_iter(train_texts, train_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fdd54feb-ac16-4b37-ad93-c571e6277e63",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Loading model cost 0.602 seconds.\n",
      "Prefix dict has been built successfully.\n"
     ]
    }
   ],
   "source": [
    "from torchtext.data.utils import get_tokenizer\n",
    "from torchtext.vocab import build_vocab_from_iterator\n",
    "\n",
    "# Word segmentation is delegated to jieba; `lcut` returns the segments of a\n",
    "# string as a list.\n",
    "tokenizer = jieba.lcut\n",
    "\n",
    "\n",
    "def yield_tokens(data_iter):\n",
    "    for text, _ in data_iter:\n",
    "        yield tokenizer(text)\n",
    "\n",
    "\n",
    "# The unknown-word token is the only special entry; its index is also the\n",
    "# default returned for out-of-vocabulary tokens.\n",
    "unk_token = \"<unk>\"\n",
    "vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=[unk_token])\n",
    "vocab.set_default_index(vocab[unk_token])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4b036478-3f3f-48bf-b562-40bc295f6060",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['<unk>', '的', '我', '一下', '播放', '是', '吗', '给', '帮', '一个']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the ten lowest-index vocabulary entries (index 0 is \"<unk>\" in the saved output).\n",
    "vocab.get_itos()[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b6829460-2e9f-4630-bff9-4513e7c989e5",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[2, 3, 41]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: look up the vocabulary indices of a few tokens.\n",
    "vocab(['我', '一下', '今天'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b1c03326-b060-46bd-a35d-71e3debb3075",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# from gensim.models import KeyedVectors\n",
    "\n",
    "# # Download the embedding file yourself, then update the path below before running.\n",
    "# wv_from_text = KeyedVectors.load_word2vec_format('/home/lyz/work/dataset/词向量/tencent-ailab-embedding-zh-d100-v0.2.0-s/tencent-ailab-embedding-zh-d100-v0.2.0-s.txt', binary=False)\n",
    "\n",
    "# pretrained_w2v = []\n",
    "# for w in vocab.get_itos():\n",
    "#     if w in wv_from_text:\n",
    "#         pretrained_w2v.append(wv_from_text[w])\n",
    "#     else:\n",
    "#         pretrained_w2v.append(np.random.rand(100))\n",
    "        \n",
    "# pretrained_w2v = np.vstack(pretrained_w2v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "39b1f63f-36f8-45cb-bacd-69278e2dd790",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def text_pipeline(x): return vocab(tokenizer(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "399e7cc9-bd92-4860-8776-3d6ae991dbf4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Smoke-test the pipeline on one sentence and wrap the ids in an int64 tensor.\n",
    "sample_ids = text_pipeline('今天我们在这里')\n",
    "processed_text = torch.tensor(sample_ids, dtype=torch.int64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "0d3373a6-4ffc-42ac-b0d7-c43dd3bc5e22",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torch.nn.utils.rnn import pad_sequence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "cce7e02f-ee05-48a0-9aa1-2bcfde8eef30",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Everything runs on the CPU; the commented line below is the automatic\n",
    "# GPU/CPU choice, left disabled.\n",
    "# device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "device = 'cpu'\n",
    "\n",
    "# 90%的文本长度都在20个单词以内\n",
    "def collate_batch(batch, max_len=20):\n",
    "    label_list, text_list = [], []\n",
    "    for (_text, _label) in batch:\n",
    "        label_list.append(_label)\n",
    "        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)\n",
    "        processed_text = F.pad(processed_text, pad=[0, max_len,], mode='constant', value=0)\n",
    "        if len(processed_text) > max_len:\n",
    "            processed_text = processed_text[:max_len]\n",
    "\n",
    "        text_list.append(processed_text)\n",
    "    label_list = torch.tensor(label_list, dtype=torch.int64)\n",
    "    text_list = pad_sequence(text_list).T\n",
    "    return label_list.to(device), text_list.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "c1f3ca37-4e2b-4ca7-a642-216a3b237a7c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class BILSTM(torch.nn.Module):\n",
    "    def __init__(self, vocab_size, embedding_dim, hidden_dim, label_size):\n",
    "        super(BILSTM, self).__init__()\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.embeddings = torch.nn.Embedding(vocab_size, embedding_dim)\n",
    "        # self.embeddings.weight.data.copy_(torch.from_numpy(pretrained_w2v))\n",
    "        \n",
    "        self.lstm = torch.nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, bidirectional=True)\n",
    "        self.hidden2label = torch.nn.Linear(hidden_dim*2, label_size)\n",
    "        \n",
    "\n",
    "    def forward(self, sentence):\n",
    "        # torch.Size([16, 20])\n",
    "        # print(sentence.shape)\n",
    "        sentence = torch.transpose(sentence, 1, 0)\n",
    "        \n",
    "        # torch.Size([20, 16])\n",
    "        # print(sentence.shape)\n",
    "        \n",
    "        # torch.Size([20, 16, 100])\n",
    "        x = self.embeddings(sentence)\n",
    "        # print(x.shape)\n",
    "        \n",
    "        # 5 seqence length\n",
    "        # 3 batch size\n",
    "        # 10 input size\n",
    "        # input = torch.randn(5, 3, 10)\n",
    "\n",
    "        lstm_out, self.hidden = self.lstm(x)\n",
    "        # torch.Size([20, 16, 128])\n",
    "        # print(lstm_out.shape)\n",
    "        y = self.hidden2label(lstm_out[-1,:,:])\n",
    "        return y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d09e44a3-c003-4909-838f-50e533bff31d",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 3, 2])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch check: torch.transpose swaps dims 0 and 2, so (2, 3, 1) becomes (1, 3, 2).\n",
    "torch.transpose(torch.rand((2,3,1)), 0, 2).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "bf4c839f-24e5-49f8-b7fa-1f6e5eb17722",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "def train(dataloader):\n",
    "    model.train()\n",
    "    total_acc, total_count = 0, 0\n",
    "\n",
    "    for idx, (label, text) in enumerate(dataloader):\n",
    "        optimizer.zero_grad()\n",
    "        predicted_label = model(text)\n",
    "        loss = criterion(predicted_label, label)\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)\n",
    "        optimizer.step()\n",
    "        total_acc += (predicted_label.argmax(1) == label).sum().item()\n",
    "        total_count += label.size(0)\n",
    "\n",
    "def evaluate(dataloader):\n",
    "    model.eval()\n",
    "    total_acc, total_count = 0, 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for idx, (label, text) in enumerate(dataloader):\n",
    "            predicted_label = model(text)\n",
    "            loss = criterion(predicted_label, label)\n",
    "            total_acc += (predicted_label.argmax(1) == label).sum().item()\n",
    "            total_count += label.size(0)\n",
    "    return total_acc/total_count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "322569ce-b658-478c-a881-86eaaff6af90",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Size the network from the data: one output per distinct label and one\n",
    "# embedding row per vocabulary entry.\n",
    "num_class, vocab_size = len(lbl), len(vocab)\n",
    "emsize = 100  # embedding dimension\n",
    "hidden_size = 64  # hidden units per LSTM direction\n",
    "model = BILSTM(vocab_size, emsize, hidden_size, num_class).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "c66d6645-5ef9-4ecd-af59-9bf9e572e3ee",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torch.utils.data.dataset import random_split\n",
    "from torchtext.data.functional import to_map_style_dataset\n",
    "# Hyperparameters\n",
    "EPOCHS = 40  # number of passes over the training split\n",
    "LR = 2  # initial SGD learning rate\n",
    "BATCH_SIZE = 16  # samples per batch\n",
    "SPLIT_SEED = 42  # fixes the train/validation partition across runs\n",
    "\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=LR)\n",
    "# step_size is documented as an int: every 5th scheduler.step() call scales\n",
    "# the learning rate by 0.75.\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 5, gamma=0.75)\n",
    "total_accu = None  # best validation accuracy seen so far\n",
    "\n",
    "# Generators are single-pass, so build a fresh one for the dataset.\n",
    "train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])\n",
    "train_dataset = to_map_style_dataset(train_iter)\n",
    "\n",
    "# 75% of the samples train the model, the rest validate it. The seeded\n",
    "# generator keeps the partition identical from run to run.\n",
    "num_train = int(len(train_dataset) * 0.75)\n",
    "split_train_, split_valid_ = random_split(\n",
    "    train_dataset,\n",
    "    [num_train, len(train_dataset) - num_train],\n",
    "    generator=torch.Generator().manual_seed(SPLIT_SEED),\n",
    ")\n",
    "\n",
    "train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,\n",
    "                              shuffle=True, collate_fn=collate_batch)\n",
    "# Accuracy does not depend on batch order, so validation is not shuffled.\n",
    "valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,\n",
    "                              shuffle=False, collate_fn=collate_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "9f60cb58-a65c-420b-a487-5bfb9a14f7db",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 12])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Grab a single batch and run it through the model as a smoke test.\n",
    "for (label, text) in train_dataloader:\n",
    "    break\n",
    "    \n",
    "model(text).shape\n",
    "# The model returns one score per class for each sample, so the shape is\n",
    "# (batch size, number of classes) -- torch.Size([16, 12]) in the saved output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9af00bc2-9949-4f85-a785-47a227a21569",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "for epoch in range(1, EPOCHS + 1):\n",
    "    epoch_start_time = time.time()\n",
    "    train(train_dataloader)\n",
    "    accu_val = evaluate(valid_dataloader)\n",
    "\n",
    "    # Advance the LR scheduler only when validation accuracy fell below the\n",
    "    # best value so far; otherwise record the new best.\n",
    "    accuracy_dropped = total_accu is not None and total_accu > accu_val\n",
    "    if accuracy_dropped:\n",
    "        scheduler.step()\n",
    "    else:\n",
    "        total_accu = accu_val\n",
    "\n",
    "    elapsed = time.time() - epoch_start_time\n",
    "    print('| end of epoch {:3d} | time: {:5.2f}s | '\n",
    "          'valid accuracy {:8.3f} '.format(epoch, elapsed, accu_val))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50ef9d28-fd43-41a7-a7b0-35171878b5d9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8980e9ae-fb3f-4a16-8792-565cbd1e4f77",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Pair every test text with a dummy label of 0 so the same (text, label)\n",
    "# pipeline and collate function can be reused for inference.\n",
    "placeholder_labels = [0] * len(test_data)\n",
    "test_iter = coustom_data_iter(test_data[0].values[:], placeholder_labels)\n",
    "test_dataset = to_map_style_dataset(test_iter)\n",
    "# No shuffling: predictions must stay aligned with the test rows.\n",
    "test_dataloader = DataLoader(test_dataset, shuffle=False,\n",
    "                             batch_size=BATCH_SIZE, collate_fn=collate_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33b9132d-cbdf-4c91-bea8-6cac6b737471",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def predict(dataloader):\n",
    "    model.eval()\n",
    "\n",
    "    test_pred = []\n",
    "    with torch.no_grad():\n",
    "        for idx, (label, text) in enumerate(dataloader):\n",
    "            predicted_label = model(text).argmax(1)\n",
    "            test_pred += list(predicted_label.cpu().numpy())\n",
    "    return test_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7d1273c-9a20-43bf-83aa-a3c5b56881bf",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Predict class indices for the test set, then map each index back to its\n",
    "# original label via `lbl`.\n",
    "pred_codes = predict(test_dataloader)\n",
    "test_pred = [lbl[code] for code in pred_codes]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44907273-d712-47bb-985b-bde269132bea",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Write the submission file: 1-based row ids alongside the predicted labels.\n",
    "submission = pd.DataFrame({\n",
    "    'ID': range(1, len(test_pred) + 1),\n",
    "    'Target': test_pred,\n",
    "})\n",
    "submission.to_csv('nlp_submit.csv', index=None)\n",
    "\n",
    "# Ready to submit!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15046e8c-1e7b-4b99-83d4-0e4211217f29",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.11",
   "language": "python",
   "name": "py3.11"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
